import tqdm
import json
import copy
import os
import logging
import random
import numpy as np
import pdb
import torch.nn as nn
import sys
import torch
import transformers
from dataclasses import dataclass, field
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
import global_vars
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import CyclicLR
from transformers import AutoModelForCausalLM

class CustomDataset(Dataset):
    """Dataset over a sequence of per-sample dicts of recorded layer arguments.

    Each sample is a mapping with the keys ``train_data_input``,
    ``train_data_output``, ``train_attention_mask``, ``train_position_ids``,
    ``train_past_key_value``, ``train_output_attentions``, ``train_use_cache``,
    ``train_cache_position`` and ``train_position_embeddings``.
    """

    def __init__(self, data):
        # ``data`` only needs to support ``len()`` and integer indexing.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a 9-tuple.

        Order: input, attention mask, position ids, past key/value,
        output-attentions flag, use-cache, cache position, position
        embeddings, label. Every entry except the output-attentions flag has
        ``squeeze(0)`` applied; ``None`` entries are passed through unchanged.
        """
        record = self.data[idx]
        features = record['train_data_input']
        target = record['train_data_output']

        def squeezed(value):
            # Drop a leading size-1 dimension, leaving missing values as None.
            return None if value is None else value.squeeze(0)

        def squeezed_field(key):
            return squeezed(record[key])

        def squeezed_position_embeddings():
            # Position embeddings are a sequence of tensors (or None entries).
            embeddings = record["train_position_embeddings"]
            if embeddings is None:
                return None
            return tuple(squeezed(part) for part in embeddings)

        return (
            squeezed(features),
            squeezed_field("train_attention_mask"),
            squeezed_field("train_position_ids"),
            squeezed_field("train_past_key_value"),
            # The output-attentions entry is returned exactly as stored.
            record["train_output_attentions"],
            squeezed_field("train_use_cache"),
            squeezed_field("train_cache_position"),
            squeezed_position_embeddings(),
            squeezed(target),
        )

def custom_collate_fn(batch):
    """Collate samples with PyTorch's default collate, tolerating ``None`` fields.

    Every top-level ``None`` entry of a sample is replaced by an empty tensor
    (``torch.tensor([])``) before the batch is handed to ``default_collate``.
    Entries nested inside a tuple are not inspected.
    """
    def fill_missing(value):
        return torch.tensor([]) if value is None else value

    prepared = [tuple(fill_missing(field) for field in entry) for entry in batch]
    return torch.utils.data.dataloader.default_collate(prepared)

# _zh
# Run configuration. The indices select the training-data file below and are
# read again by main() to pick the layer and to name the saved checkpoint.
begin_index, end_index = 28, 29
model_name = 'Llama-3.1-8B'
path = f'/root/SLEB/_zh/data/{model_name}{begin_index}-{end_index}.pt'
# Deserialise the stored samples onto the CPU; this runs at import time.
train_data = torch.load(path, map_location="cpu")

# One sample per batch, in stored order.
train_dataset = CustomDataset(train_data)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=custom_collate_fn,
)

def _to_cuda_or_none(value):
    """Move a collated tensor to the GPU, mapping the collate placeholder to ``None``.

    ``custom_collate_fn`` substitutes an empty tensor for every ``None`` field,
    so a tensor without elements stands for "argument not provided" and must
    not be forwarded to the layer as if it were real data.
    """
    if value is None or (torch.is_tensor(value) and value.numel() == 0):
        return None
    return value.cuda()


def main():
    """Train one decoder layer against the stored targets and save its weights.

    Loads the causal LM from the hard-coded checkpoint path, keeps only
    ``model.model.layers[begin_index]``, casts it to float32, freezes the
    parameters named in ``frozen_names`` and minimises the MSE between the
    layer's output and the labels yielded by the module-level ``train_loader``.

    Training ends after ``num_epochs`` passes over the data or as soon as the
    ``num_batch``-th batch is drawn (that batch is not used for an update),
    whichever comes first. The layer's ``state_dict`` is written in both cases.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    batches_per_epoch = len(train_loader)
    if batches_per_epoch == 0:
        # Fail before the expensive model load instead of dividing by zero later.
        raise ValueError("train_loader is empty; nothing to train on")

    # _zh
    model_name = '/root/model/Llama-3.1-8B'
    full_model = AutoModelForCausalLM.from_pretrained(model_name,
                                                      torch_dtype='auto',
                                                      low_cpu_mem_usage=True,
                                                      device_map='cuda:0',
                                                      attn_implementation="eager"
                                                      )
    # Only a single decoder layer is trained; drop the reference to the rest.
    model = full_model.model.layers[begin_index]
    del full_model
    model = model.float()
    model.train()

    # Parameters of the layer that stay fixed during training. Names that do
    # not exist on the layer are simply never matched.
    frozen_names = frozenset([
        'self_attn.q_proj.weight', 'self_attn.q_proj.bias',
        'self_attn.k_proj.weight', 'self_attn.k_proj.bias',
        'self_attn.v_proj.weight', 'self_attn.v_proj.bias',
        'input_layernorm.weight',
    ])
    for name, param in model.named_parameters():
        if name in frozen_names:
            param.requires_grad = False
        print('{} : {}'.format(name, param.requires_grad))

    criterion = nn.MSELoss()
    # Hand the optimizer only the parameters that are actually trained.
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=5e-6)
    num_epochs = 100
    num_batch = 4096 * 3
    now_batch = 0
    budget_reached = False
    for epoch in range(num_epochs):
        running_loss = 0.0
        # The stored attention mask, cache and flag fields are not forwarded to
        # the layer (it is called with fixed values below), so they are not
        # transferred to the GPU either.
        for (inputs, _attention_mask, train_position_ids, _past_key_value,
             _output_attentions, _use_cache, train_cache_position,
             train_position_embeddings, labels) in train_loader:
            now_batch += 1
            if now_batch >= num_batch:
                # Step budget exhausted: stop before spending a forward pass
                # on a batch that would not contribute an update.
                budget_reached = True
                break
            inputs = inputs.cuda()
            labels = labels.cuda()
            train_position_ids = _to_cuda_or_none(train_position_ids)
            train_cache_position = _to_cuda_or_none(train_cache_position)
            if isinstance(train_position_embeddings, (list, tuple)):
                train_position_embeddings = [tensor.cuda() for tensor in train_position_embeddings]
            else:
                train_position_embeddings = _to_cuda_or_none(train_position_embeddings)
            optimizer.zero_grad()
            outputs = model(inputs,
                            attention_mask=None,
                            position_ids=train_position_ids,
                            past_key_value=None,
                            output_attentions=False,
                            use_cache=None,
                            cache_position=train_cache_position,
                            position_embeddings=train_position_embeddings,)
            # NOTE(review): depending on the transformers version the layer
            # returns either a tuple whose first item is the hidden states or
            # the hidden states tensor itself; index only in the tuple case.
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            loss = criterion(outputs, labels)
            # print('loss : {}'.format(loss.data))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        if budget_reached:
            break
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / batches_per_epoch:.4f}")

    # Save the trained layer whether training ended by step budget or by
    # running out of epochs.
    # _zh
    name = 'Llama-3.1-8B'
    save_path = '/root/SLEB/_zh/model/' + name + str(begin_index) + '-' + str(end_index) + '.pt'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)


# Script entry point: start training only when this file is executed directly.
# Note that the training data is still loaded at module level, i.e. on import too.
if __name__ == "__main__":
    main()